package com.uber.nullaway.jdkannotations;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.uber.nullaway.javacplugin.NullnessAnnotationSerializer.ClassInfo;
import com.uber.nullaway.javacplugin.NullnessAnnotationSerializer.MethodInfo;
import com.uber.nullaway.javacplugin.NullnessAnnotationSerializer.TypeParamInfo;
import com.uber.nullaway.libmodel.MethodAnnotationsRecord;
import com.uber.nullaway.libmodel.StubxWriter;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Generates astubx files for NullAway from the JSON files that the jdk-javac-plugin module produces
 * from Java source files. The generated astubx files contain the annotation information that
 * NullAway needs.
 */
public class AstubxGenerator {

  /** Extracts the parenthesized argument list from a method name such as {@code foo(int, T)}. */
  private static final Pattern METHOD_ARGS_PATTERN = Pattern.compile(".*\\((.*)\\)");

  /**
   * Matches a JSpecify {@code @Nullable} annotation written either fully qualified or by simple
   * name. The trailing word boundary keeps names such as {@code @NullableDecl} from matching.
   */
  private static final Pattern NULLABLE_ANNOTATION =
      Pattern.compile("@(?:org\\.jspecify\\.annotations\\.)?Nullable\\b");

  /** Matches one annotation token (e.g. {@code @org.jspecify.annotations.NonNull}) and any trailing whitespace. */
  private static final Pattern ANY_ANNOTATION = Pattern.compile("@[\\w.$]+\\s*");

  /**
   * Contains all information that will be added to the astubx file.
   *
   * @param importedAnnotations Mapping of 'custom annotations' to their 'definition classes'.
   * @param packageAnnotations Map of 'package names' to their 'list of package-level annotations'.
   * @param typeAnnotations Map of 'type names' to their 'list of type annotations'.
   * @param methodRecords Map of 'method signatures' to their 'method annotations record'. Method
   *     annotations record consists of return value annotations and argument annotations. {@link
   *     MethodAnnotationsRecord}
   * @param nullableUpperBounds Map of fully qualified name to a set of indices of type parameters
   *     that have nullable upper bounds.
   * @param nullMarkedClasses Set of fully qualified name of NullMarked classes
   */
  public record AstubxData(
      ImmutableMap<String, String> importedAnnotations,
      Map<String, Set<String>> packageAnnotations,
      Map<String, Set<String>> typeAnnotations,
      Map<String, MethodAnnotationsRecord> methodRecords,
      Map<String, Set<Integer>> nullableUpperBounds,
      Set<String> nullMarkedClasses) {}

  /**
   * This method generates an astubx file from jdk-javac-plugin generated JSON files, which contains
   * module information.
   *
   * @param jsonDirPath The path to the directory that contains the JSON files generated by the
   *     jdk-javac-plugin.
   * @param astubxDirPath The directory path to generate the astubx file.
   */
  public static void generateAstubx(String jsonDirPath, String astubxDirPath) {
    AstubxData astubxData = getAstubxData(jsonDirPath);
    writeToAstubxFile(astubxDirPath, astubxData);
  }

  /**
   * Reads every {@code .json} file in {@code jsonDirPath} and converts the class information into
   * the data written to an astubx file.
   *
   * @param jsonDirPath directory containing the JSON files generated by the jdk-javac-plugin.
   * @return the collected astubx data.
   * @throws IllegalArgumentException if {@code jsonDirPath} is not a directory.
   * @throws IllegalStateException if the directory contains no JSON files.
   */
  public static AstubxData getAstubxData(String jsonDirPath) {
    Map<String, List<ClassInfo>> parsed = parseJson(jsonDirPath);

    ImmutableMap<String, String> importedAnnotations =
        ImmutableMap.of(
            "NonNull", "org.jspecify.annotations.NonNull",
            "Nullable", "org.jspecify.annotations.Nullable");
    // jspecify/jdk has no @NullMarked package-info.java files, so the jdk-javac-plugin does not
    // emit package-level information and these two maps stay empty.
    Map<String, Set<String>> packageAnnotations = new HashMap<>();
    Map<String, Set<String>> typeAnnotations = new HashMap<>();
    Map<String, MethodAnnotationsRecord> methodRecords = new LinkedHashMap<>();
    Set<String> nullMarkedClasses = new LinkedHashSet<>();
    Map<String, Set<Integer>> nullableUpperBounds = new LinkedHashMap<>();

    for (List<ClassInfo> classesInModule : parsed.values()) {
      for (ClassInfo clazz : classesInModule) {
        String fullyQualifiedClassName = stripTypeArguments(clazz.type());
        if (clazz.nullMarked()) {
          nullMarkedClasses.add(fullyQualifiedClassName);
        }
        Set<Integer> nullableUpperBoundIndices = getNullableUpperBoundIndices(clazz);
        if (!nullableUpperBoundIndices.isEmpty()) {
          nullableUpperBounds.put(fullyQualifiedClassName, nullableUpperBoundIndices);
        }
        getMethodRecords(clazz, fullyQualifiedClassName, methodRecords);
      }
    }
    return new AstubxData(
        importedAnnotations,
        packageAnnotations,
        typeAnnotations,
        methodRecords,
        nullableUpperBounds,
        nullMarkedClasses);
  }

  /**
   * Writes {@code astubxData} to {@code <astubxDirPath>/output.astubx}, creating the directory if
   * needed.
   *
   * @param astubxDirPath directory in which the astubx file is written.
   * @param astubxData the data to serialize.
   * @throws RuntimeException wrapping the {@link IOException} if the directory cannot be created or
   *     the file cannot be written.
   */
  public static void writeToAstubxFile(String astubxDirPath, AstubxData astubxData) {
    try {
      Files.createDirectories(Paths.get(astubxDirPath));
    } catch (IOException e) {
      System.err.println("Failed to create directory: " + astubxDirPath);
      throw new RuntimeException(e);
    }
    File outputFile = new File(astubxDirPath, "output.astubx");
    // Buffer the stream: StubxWriter emits many small primitive writes. Closing the
    // DataOutputStream flushes the buffer, so a failed flush is reported by the catch below.
    try (DataOutputStream out =
        new DataOutputStream(new BufferedOutputStream(new FileOutputStream(outputFile)))) {
      StubxWriter.write(
          out,
          astubxData.importedAnnotations(),
          astubxData.packageAnnotations(),
          astubxData.typeAnnotations(),
          astubxData.methodRecords(),
          astubxData.nullMarkedClasses(),
          astubxData.nullableUpperBounds());
    } catch (IOException e) {
      System.err.println("Error writing astubx file: " + outputFile.getAbsolutePath());
      throw new RuntimeException(e);
    }
  }

  /**
   * This method parses the JSON files generated by the jdk-javac-plugin, and returns the
   * information as a Map from module name to information for classes in that module.
   *
   * @param jsonDirPath The path to the JSON files.
   * @return A Map from module name to information for classes in that module.
   */
  private static Map<String, List<ClassInfo>> parseJson(String jsonDirPath) {
    File jsonDir = new File(jsonDirPath);
    if (!jsonDir.isDirectory()) {
      throw new IllegalArgumentException(
          "JSON directory does not exist or is not a directory: " + jsonDirPath);
    }

    File[] jsonFiles = jsonDir.listFiles((dir, name) -> name.endsWith(".json"));
    if (jsonFiles == null || jsonFiles.length == 0) {
      throw new IllegalStateException("No JSON files found in: " + jsonDirPath);
    }
    // listFiles() guarantees no order. Shards are merged in file order, so sort them to make the
    // generated astubx reproducible across machines and file systems.
    Arrays.sort(jsonFiles);

    Gson gson = new Gson();
    Type parsedType = new TypeToken<Map<String, List<ClassInfo>>>() {}.getType();

    Map<String, List<ClassInfo>> parsed = new HashMap<>();
    for (File jsonFile : jsonFiles) {
      String jsonContent;
      try {
        jsonContent = Files.readString(jsonFile.toPath());
      } catch (IOException e) {
        System.err.println("Error reading JSON file: " + jsonFile.getAbsolutePath());
        throw new RuntimeException(e);
      }
      Map<String, List<ClassInfo>> parsedShard = gson.fromJson(jsonContent, parsedType);
      if (parsedShard == null) {
        // Gson returns null for an empty document; such a shard contributes nothing.
        continue;
      }
      parsedShard.forEach(
          (module, classes) -> {
            if (classes != null) {
              parsed.computeIfAbsent(module, unused -> new ArrayList<>()).addAll(classes);
            }
          });
    }
    return parsed;
  }

  /** Returns {@code type} with any {@code <...>} type-argument suffix removed. */
  private static String stripTypeArguments(String type) {
    int angle = type.indexOf('<');
    return angle == -1 ? type : type.substring(0, angle);
  }

  /**
   * Returns the indices of {@code clazz}'s type parameters that have a {@code @Nullable} upper
   * bound. Annotations nested inside a bound's own type arguments (for example the one in
   * {@code Comparable<@Nullable T>}) are ignored, because they do not make the bound nullable.
   */
  private static Set<Integer> getNullableUpperBoundIndices(ClassInfo clazz) {
    Set<Integer> indices = new LinkedHashSet<>();
    var typeParams = clazz.typeParams();
    for (int idx = 0; idx < typeParams.size(); idx++) {
      TypeParamInfo typeParam = typeParams.get(idx);
      for (String bound : typeParam.bounds()) {
        if (NULLABLE_ANNOTATION.matcher(removeGenericAnnotations(bound)).find()) {
          indices.add(idx);
          break;
        }
      }
    }
    return indices;
  }

  /**
   * Adds one entry per method of {@code clazz} to {@code methodRecords}. The key is the method
   * signature {@code <class>:<returnType> <name>(<arg>, ...)} with all annotations stripped. The
   * value records whether the return type and each argument are {@code @Nullable}.
   */
  private static void getMethodRecords(
      ClassInfo clazz,
      String fullyQualifiedClassName,
      Map<String, MethodAnnotationsRecord> methodRecords) {
    for (MethodInfo method : clazz.methods()) {
      String methodName = method.name();

      String returnType = removeGenericAnnotations(method.returnType());
      ImmutableSet<String> returnTypeNullness = ImmutableSet.of();
      if (returnType.indexOf('@') != -1) {
        // NOTE(review): an annotation on an array's element type (e.g. `@Nullable Object[]`) is
        // treated here as making the whole return type nullable. Confirm this against the
        // serializer's output format.
        if (NULLABLE_ANNOTATION.matcher(returnType).find()) {
          returnTypeNullness = ImmutableSet.of("Nullable");
        }
        // Strip every top-level annotation, not just @Nullable, so that the signature carries
        // no annotations, matching how arguments are handled below.
        returnType = ANY_ANNOTATION.matcher(returnType).replaceAll("");
        returnType = returnType.replace(" []", "[]"); // "T @A []" prints with a space
      }

      Map<Integer, ImmutableSet<String>> argAnnotation = new LinkedHashMap<>();
      String[] argumentList = getArgumentsAsArray(methodName);
      for (int i = 0; i < argumentList.length; i++) {
        String typeSignature = removeGenericAnnotations(argumentList[i].trim());
        if (typeSignature.indexOf('@') != -1) {
          StringBuilder stripped = new StringBuilder();
          for (String token : typeSignature.split(" ")) {
            int at = token.indexOf('@');
            if (at == -1) {
              stripped.append(token);
            } else {
              if (NULLABLE_ANNOTATION.matcher(token).find()) {
                argAnnotation.put(i, ImmutableSet.of("Nullable"));
              }
              // keep the qualifier preceding the annotation, e.g. "java.lang." in
              // "java.lang.@Nullable String"
              stripped.append(token, 0, at);
            }
          }
          typeSignature = stripped.toString();
        } else {
          // NOTE(review): this removes every space, including those in wildcards
          // ("? extends T" becomes "?extendsT"). Verify that this matches the lookup side.
          typeSignature = typeSignature.replace(" ", "");
        }
        argumentList[i] = typeSignature;
      }

      String signatureForMethodRecords =
          fullyQualifiedClassName
              + ":"
              + returnType
              + " "
              + methodName.substring(0, methodName.indexOf('(') + 1)
              + String.join(", ", argumentList)
              + ")";
      methodRecords.put(
          signatureForMethodRecords,
          MethodAnnotationsRecord.create(returnTypeNullness, ImmutableMap.copyOf(argAnnotation)));
    }
  }

  /**
   * Splits the argument list of {@code methodName} into individual argument type strings. It
   * splits only on commas that are not nested inside {@code <...>}, and trims and drops empty
   * entries.
   */
  private static String[] getArgumentsAsArray(String methodName) {
    String argsOnly = "";
    Matcher matcher = METHOD_ARGS_PATTERN.matcher(methodName);
    if (matcher.matches()) {
      argsOnly = matcher.group(1).trim();
    }
    if (argsOnly.isEmpty()) {
      return new String[0];
    }

    List<String> output = new ArrayList<>();
    StringBuilder cur = new StringBuilder();
    int depth = 0; // nesting level for '<' ... '>'
    for (int i = 0; i < argsOnly.length(); i++) {
      char c = argsOnly.charAt(i);
      switch (c) {
        case '<' -> {
          depth++;
          cur.append(c);
        }
        case '>' -> {
          depth = Math.max(0, depth - 1);
          cur.append(c);
        }
        case ',' -> {
          if (depth == 0) {
            String token = cur.toString().trim();
            if (!token.isEmpty()) {
              output.add(token);
            }
            cur.setLength(0);
          } else {
            cur.append(c);
          }
        }
        default -> cur.append(c);
      }
    }
    String tail = cur.toString().trim();
    if (!tail.isEmpty()) {
      output.add(tail);
    }
    return output.toArray(String[]::new);
  }

  /**
   * Removes annotations that appear inside {@code <...>} type arguments and leaves top-level
   * annotations untouched. An annotation is assumed to run from {@code '@'} up to the next space;
   * annotations with arguments are not supported.
   */
  private static String removeGenericAnnotations(String typeSignature) {
    if (typeSignature.indexOf('<') != -1) {
      StringBuilder withoutGenericAnnotations = new StringBuilder();
      int depth = 0;
      int annotationDepth = 0;
      for (int j = 0; j < typeSignature.length(); j++) {
        char ch = typeSignature.charAt(j);
        if (ch == '<') {
          depth++;
          withoutGenericAnnotations.append(ch);
        } else if (ch == '>') {
          depth = Math.max(0, depth - 1);
          withoutGenericAnnotations.append(ch);
        } else if (depth == 0) {
          withoutGenericAnnotations.append(ch);
        } else if (ch == '@') {
          annotationDepth++;
        } else if (ch == ' ' && annotationDepth != 0) {
          // the space terminates the annotation name; it is dropped too
          annotationDepth = Math.max(0, annotationDepth - 1);
        } else if (annotationDepth == 0) {
          withoutGenericAnnotations.append(ch);
        }
      }
      typeSignature = withoutGenericAnnotations.toString().trim();
    }
    return typeSignature;
  }
}
